Generative Adversarial Networks (GAN)

GANs are one of the fastest-moving research areas in neural networks; it can feel as if a new GAN variant appears every week. To explain the concept, let's start with a small anecdote. In old movies, a criminal is sketched by an artist working with a witness. The witness describes the criminal, the artist draws a sketch, and the witness checks the drawing and says whether it is correct. If the image does not resemble the criminal, the artist redraws it with further changes. This repeats until the artist produces an image the witness accepts, in other words until the witness can no longer tell the artist's drawing apart from the real criminal. At that point they stop.

A GAN works in a similar way. A generator network produces images from random inputs, and a discriminator network classifies each image as real (taken from the dataset) or fake (produced by the generator). The two networks are trained against each other: the discriminator learns to separate real images from generated ones, while the generator learns to produce images that the discriminator classifies as real. Ideally, training continues until the discriminator can no longer distinguish generated images from real ones. The generator plays the role of the decoder from the autoencoder discussed in the previous tutorial: we take a random codeword (a latent vector z), pass it through the generator to produce an image, and feed that image to the discriminator, which decides whether it is real or fake. To make this work, we keep the discriminator a step ahead of the generator; the code below does this through the pre_train_discriminator and k arguments of net.train.
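
The two objectives that the YANN code below builds as tensor layers (discriminator_task and objective_task) can be written out directly. The following is a minimal NumPy sketch, assuming d_real and d_fake hold the discriminator's sigmoid outputs on a batch of real images and a batch of generated images; the function names and the small EPS clipping constant are illustrative additions, not part of YANN. It also shows the end point of the witness analogy: when the discriminator outputs 0.5 for every image, it cannot tell real from fake, and its loss equals log 2.

```python
import numpy as np

EPS = 1e-12  # keeps log() away from 0; an illustrative safeguard, not in the YANN code


def discriminator_loss(d_real, d_fake):
    """-0.5*mean(log D(x)) - 0.5*mean(log(1 - D(G(z)))).

    Low when real images score near 1 and generated images near 0."""
    d_real = np.clip(d_real, EPS, 1.0 - EPS)
    d_fake = np.clip(d_fake, EPS, 1.0 - EPS)
    return -0.5 * np.mean(np.log(d_real)) - 0.5 * np.mean(np.log(1.0 - d_fake))


def generator_loss(d_fake):
    """-0.5*mean(log D(G(z))).

    Low when the discriminator scores generated images near 1 (i.e. 'real')."""
    d_fake = np.clip(d_fake, EPS, 1.0 - EPS)
    return -0.5 * np.mean(np.log(d_fake))


# A confident, correct discriminator: real images -> 0.99, generated -> 0.01.
print(discriminator_loss(np.full(500, 0.99), np.full(100, 0.01)))  # close to 0
print(generator_loss(np.full(100, 0.01)))  # large: the generator is losing

# Equilibrium: the discriminator outputs 0.5 everywhere.
print(discriminator_loss(np.full(500, 0.5), np.full(100, 0.5)))  # log 2, about 0.693
print(generator_loss(np.full(100, 0.5)))  # 0.5 * log 2, about 0.347
```

The batch sizes 500 and 100 mirror the real and generated batches in the network below; any sizes work because both losses average over the batch.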

The following code implements a GAN with YANN. For GANs, YANN provides the yann.special.gan package, whose gan class is used much like a regular network: you add modules and layers, cook it, and train it.
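
Before the YANN code, it may help to see the data flow it sets up. Real images x and generated images G(z) go through one shared discriminator: the D(G(z)) and fake layers reuse the parameters of D(x) and real via input_params. Below is a minimal NumPy sketch of that forward pass, assuming the shapes the network reports when it is printed (batches of 500 real images and 100 latent codes of size 32). The random weights, variable names and helper functions are illustrative only; this is not how YANN stores its parameters, and the 10-class softmax head is left out.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(a):
    return np.maximum(a, 0.0)


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


# Generator G(z): 32-d latent code -> 784 pixels, tanh output in [-1, 1].
W_g = rng.normal(0.0, 0.05, size=(32, 784))
b_g = np.zeros(784)

# Discriminator parameters, created ONCE and shared by both paths.
W_d = rng.normal(0.0, 0.05, size=(784, 800))  # plays "D(x)" and "D(G(z))"
b_d = np.zeros(800)
w_c = rng.normal(0.0, 0.05, size=(800, 1))    # plays "real" and "fake"
b_c = np.zeros(1)


def generator(z):
    return np.tanh(z @ W_g + b_g)


def discriminator(images):
    features = relu(images @ W_d + b_d)   # 800 features per image
    return sigmoid(features @ w_c + b_c)  # probability that the image is real


x = rng.normal(size=(500, 784))           # stand-in for a flattened MNIST batch
z = rng.normal(0.0, 1.0, size=(100, 32))  # like the 'random' layer: normal, mu=0, sigma=1

g_z = generator(z)
d_real = discriminator(x)
d_fake = discriminator(g_z)

print(g_z.shape, d_real.shape, d_fake.shape)  # (100, 784) (500, 1) (100, 1)
```

Because the discriminator weights are created once and used for both batches, a gradient step on the discriminator objective changes how both real and generated images are scored; this is what the shared input_params achieve in the YANN network.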


from yann.special.gan import gan 
from theano import tensor as T 

def shallow_gan_mnist ( dataset= None, verbose = 1 ):
    """
    This function is a demo example of a generative adversarial network. 
    This is example code. Study it rather than merely running it.

    Args: 
        dataset: Supply a dataset.    
        verbose: Verbosity level, as in the rest of the toolbox.

    Notes:
        This method is set up for MNIST.
    """
    optimizer_params =  {        
                "momentum_type"       : 'polyak',             
                "momentum_params"     : (0.65, 0.9, 50),      
                "regularization"      : (0.000, 0.000),       
                "optimizer_type"      : 'rmsprop',                
                "id"                  : "main"
                        }


    dataset_params  = {
                            "dataset"   : dataset,
                            "type"      : 'xy',
                            "id"        : 'data'
                    }

    visualizer_params = {
                    "root"       : '.',
                    "frequency"  : 1,
                    "sample_size": 225,
                    "rgb_filters": False,
                    "debug_functions" : False,
                    "debug_layers": True,  
                    "id"         : 'main'
                        }  
                      
    # initialize the network
    net = gan (      borrow = True,
                     verbose = verbose )                       
    
    net.add_module ( type = 'datastream', 
                     params = dataset_params,
                     verbose = verbose )    
    
    net.add_module ( type = 'visualizer',
                     params = visualizer_params,
                     verbose = verbose 
                    ) 

    # z - latent codes from a random layer: 100 samples per batch, each 32-dimensional, drawn from N(0, 1)
    net.add_layer(type = 'random',
                        id = 'z',
                        num_neurons = (100,32), 
                        distribution = 'normal',
                        mu = 0,
                        sigma = 1,
                        verbose = verbose)
    
    #x - inputs come from dataset 1 X 784
    net.add_layer ( type = "input",
                    id = "x",
                    verbose = verbose, 
                    datastream_origin = 'data', # if you didn't add a dataset module, now is 
                                                 # the time. 
                    mean_subtract = False )

    net.add_layer ( type = "dot_product",
                    origin = "z",
                    id = "G(z)",
                    num_neurons = 784,
                    activation = 'tanh',
                    verbose = verbose
                    )  # This layer is the one that creates the images.
        
    # D(x) - discriminator layer with parameters theta_d; maps each image to 1 x 800 features.
    net.add_layer ( type = "dot_product",
                    id = "D(x)",
                    origin = "x",
                    num_neurons = 800,
                    activation = 'relu',
                    regularize = True,                                                         
                    verbose = verbose
                    )

    net.add_layer ( type = "dot_product",
                    id = "D(G(z))",
                    origin = "G(z)",
                    input_params = net.dropout_layers["D(x)"].params, # share parameters with D(x)
                    num_neurons = 800,
                    activation = 'relu',
                    regularize = True,
                    verbose = verbose
                    )


    # C(D(x)) - the real/fake classifier applied to real images (layer "real")
    net.add_layer ( type = "dot_product",
                    id = "real",
                    origin = "D(x)",
                    num_neurons = 1,
                    activation = 'sigmoid',
                    verbose = verbose
                    )

    # C(D(G(z))) - the same real/fake classifier applied to generated images (layer "fake")
    net.add_layer ( type = "dot_product",
                    id = "fake",
                    origin = "D(G(z))",
                    num_neurons = 1,
                    activation = 'sigmoid',
                    input_params = net.dropout_layers["real"].params, # Again share their parameters                    
                    verbose = verbose
                    )

    
    # softmax - a 10-class digit classifier on the D(x) features
    net.add_layer ( type = "classifier",
                    id = "softmax",
                    origin = "D(x)",
                    num_classes = 10,
                    activation = 'softmax',
                    verbose = verbose
                   )
    
    # objective layers 
    # discriminator objective 
    net.add_layer (type = "tensor",
                    input =  - 0.5 * T.mean(T.log(net.layers['real'].output)) - \
                                  0.5 * T.mean(T.log(1-net.layers['fake'].output)),
                    input_shape = (1,),
                    id = "discriminator_task"
                    )

    net.add_layer ( type = "objective",
                    id = "discriminator_obj",
                    origin = "discriminator_task",
                    layer_type = 'value',
                    objective = net.dropout_layers['discriminator_task'].output,
                    datastream_origin = 'data', 
                    verbose = verbose
                    )
    #generator objective 
    net.add_layer (type = "tensor",
                    input =  - 0.5 * T.mean(T.log(net.layers['fake'].output)),
                    input_shape = (1,),
                    id = "objective_task"
                    )
    net.add_layer ( type = "objective",
                    id = "generator_obj",
                    layer_type = 'value',
                    origin = "objective_task",
                    objective = net.dropout_layers['objective_task'].output,
                    datastream_origin = 'data', 
                    verbose = verbose
                    )   

    #softmax objective.    
    net.add_layer ( type = "objective",
                    id = "classifier_obj",
                    origin = "softmax",
                    objective = "nll",
                    layer_type = 'discriminator',
                    datastream_origin = 'data', 
                    verbose = verbose
                    )
    
    from yann.utils.graph import draw_network
    draw_network(net.graph, filename = 'gan.png')    
    net.pretty_print()
    
    net.cook (  objective_layers = ["classifier_obj", "discriminator_obj", "generator_obj"],
                optimizer_params = optimizer_params,
                discriminator_layers = ["D(x)"],
                generator_layers = ["G(z)"], 
                classifier_layers = ["D(x)", "softmax"],                                                
                softmax_layer = "softmax",
                game_layers = ("fake", "real"),
                verbose = verbose )
                    
    learning_rates = (0.05, 0.01 )  

    net.train( epochs = (20), 
               k = 2,  
               pre_train_discriminator = 3,
               validate_after_epochs = 1,
               visualize_after_epochs = 1,
               training_accuracy = True,
               show_progress = True,
               early_terminate = True,
               learning_rates = learning_rates,
               verbose = verbose)
                           
    return net

if __name__ == '__main__':
    
    from yann.special.datasets import cook_mnist_normalized_zero_mean as c 
    # from yann.special.datasets import cook_cifar10_normalized_zero_mean as c
    print(" creating a new dataset to run through")
    data = c (verbose = 2)
    dataset = data.dataset_location() 

    net = shallow_gan_mnist ( dataset, verbose = 2 )


WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10).  Please switch to the gpuarray backend. You can get more information about how to switch at this URL:
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: GeForce GTX 750 Ti (CNMeM is enabled with initial size: 80.0% of memory, cuDNN 5110)
 creating a new dataset to run through
. Setting up dataset 
.. setting up skdata
... Importing mnist from skdata
.. setting up dataset
.. training data
.. validation data 
.. testing data 
. Dataset 60713 is created.
. Time taken is 0.890919 seconds
. Initializing the network
.. Setting up the datastream
.. Setting up the visualizer
.. Adding random layer z
.. Adding input layer x
.. Adding dot_product layer G(z)
.. Adding dot_product layer D(x)
.. Adding flatten layer 4
.. Adding dot_product layer D(G(z))
.. Adding dot_product layer real
.. Adding dot_product layer fake
.. Adding classifier layer softmax
.. Adding tensor layer discriminator_task
.. Adding objective layer discriminator_obj
.. Adding tensor layer objective_task
.. Adding objective layer generator_obj
.. Adding objective layer classifier_obj
.. Saving the network down as an image
.. This method will be deprecated with the implementation of a visualizer,also this works only for tree-like networks. This will cause errors in printing DAG-style networks.
 |-
 |-
 |-
 |- id: objective_task
 |-=================------------------
 |- type: tensor
 |- output shape: (1,)
 |------------------------------------
          |-
          |-
          |-
          |- id: generator_obj
          |-=================------------------
          |- type: objective
          |- output shape: (1,)
          |------------------------------------
 |-
 |-
 |-
 |- id: x
 |-=================------------------
 |- type: input
 |- output shape: (500, 1, 28, 28)
 |------------------------------------
          |-
          |-
          |-
          |- id: 4
          |-=================------------------
          |- type: flatten
          |- output shape: (500, 784)
          |------------------------------------
                   |-
                   |-
                   |-
                   |- id: D(x)
                   |-=================------------------
                   |- type: dot_product
                   |- output shape: (500, 800)
                   |- batch norm is OFF
                   |------------------------------------
                            |-
                            |-
                            |-
                            |- id: real
                            |-=================------------------
                            |- type: dot_product
                            |- output shape: (500, 1)
                            |- batch norm is OFF
                            |------------------------------------
                            |-
                            |-
                            |-
                            |- id: softmax
                            |-=================------------------
                            |- type: classifier
                            |- output shape: (500, 10)
                            |------------------------------------
                                     |-
                                     |-
                                     |-
                                     |- id: classifier_obj
                                     |-=================------------------
                                     |- type: objective
                                     |- output shape: (1,)
                                     |------------------------------------
 |-
 |-
 |-
 |- id: z
 |-=================------------------
 |- type: random
 |- output shape: (100, 32)
 |------------------------------------
          |-
          |-
          |-
          |- id: G(z)
          |-=================------------------
          |- type: dot_product
          |- output shape: (100, 784)
          |- batch norm is OFF
          |------------------------------------
          |        |-
          |        |-
          |        |-
          |        |- id: D(G(z))
          |        |-=================------------------
          |        |- type: dot_product
          |        |- output shape: (100, 800)
          |        |- batch norm is OFF
          |        |------------------------------------
          |                 |-
          |                 |-
          |                 |-
          |                 |- id: fake
          |                 |-=================------------------
          |                 |- type: dot_product
          |                 |- output shape: (100, 1)
          |                 |- batch norm is OFF
          |                 |------------------------------------
          |-
          |-
          |-
          |- id: G(z)
          |-=================------------------
          |- type: dot_product
          |- output shape: (100, 784)
          |- batch norm is OFF
          |------------------------------------
                   |-
                   |-
                   |-
                   |- id: D(G(z))
                   |-=================------------------
                   |- type: dot_product
                   |- output shape: (100, 800)
                   |- batch norm is OFF
                   |------------------------------------
                            |-
                            |-
                            |-
                            |- id: fake
                            |-=================------------------
                            |- type: dot_product
                            |- output shape: (100, 1)
                            |- batch norm is OFF
                            |------------------------------------
 |-
 |-
 |-
 |- id: discriminator_task
 |-=================------------------
 |- type: tensor
 |- output shape: (1,)
 |------------------------------------
          |-
          |-
          |-
          |- id: discriminator_obj
          |-=================------------------
          |- type: objective
          |- output shape: (1,)
          |------------------------------------
.. Cooking the network
.. Setting up the resultor
.. Setting up the optimizer
.. Setting up the optimizer
.. Setting up the optimizer
. Training
.


.. Pre-Training Epoch: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Softmax Cost       : 19.0017
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 88.69
.. Training accuracy : 87.728
.. Best training accuracy
.. Best validation accuracy
.


.. Pre-Training Epoch: 1
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Softmax Cost       : 0.634408
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 89.11
.. Training accuracy : 87.636
.. Best validation accuracy
.


.. Pre-Training Epoch: 2
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Softmax Cost       : 0.690765
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 90.52
.. Training accuracy : 89.722
.. Best training accuracy
.. Best validation accuracy
.. Pre- Training complete.Took 0.181362733333 minutes
.


.. Epoch: 0 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    6% ETA:  0:00:01                                               
.. Discriminator Sigmoid D(x)   : 0.622721
.. Generator Sigmoid D(G(z))         : 0.422515
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 82.69
.. Training accuracy : 80.798
.


.. Epoch: 1 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.584707
.. Generator Sigmoid D(G(z))         : 0.312661
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 76.8
.. Training accuracy : 75.728
.


.. Epoch: 2 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    6% ETA:  0:00:01                                               
.. Discriminator Sigmoid D(x)   : 0.823963
.. Generator Sigmoid D(G(z))         : 0.57011
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 76.38
.. Training accuracy : 74.724
.


.. Epoch: 3 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.572462
.. Generator Sigmoid D(G(z))         : 0.40043
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 77.75
.. Training accuracy : 76.034
.


.. Epoch: 4 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    6% ETA:  0:00:01                                               
.. Discriminator Sigmoid D(x)   : 0.859632
.. Generator Sigmoid D(G(z))         : 0.686642
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 80.35
.. Training accuracy : 78.6
.


.. Epoch: 5 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.527774
.. Generator Sigmoid D(G(z))         : 0.37337
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 83.96
.. Training accuracy : 82.402
.


.. Epoch: 6 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    6% ETA:  0:00:01                                               
.. Discriminator Sigmoid D(x)   : 0.611699
.. Generator Sigmoid D(G(z))         : 0.455202
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 85.0
.. Training accuracy : 83.0
.


.. Epoch: 7 Era: 0
| training  100% ETA:  0:00:00                                                 
.. Discriminator Sigmoid D(x)   : 0.743451
.. Generator Sigmoid D(G(z))         : 0.64773
| training  100% Time: 0:00:00                                                 
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 78.24
.. Training accuracy : 76.71
.


.. Epoch: 8 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.575549
.. Generator Sigmoid D(G(z))         : 0.406572
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 82.22
.. Training accuracy : 81.058
.


.. Epoch: 9 Era: 0
| training  100% Time: 0:00:00                                                 
| validation   33% ETA:  0:00:00                                               
.. Discriminator Sigmoid D(x)   : 0.342317
.. Generator Sigmoid D(G(z))         : 0.287621
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 76.24
.. Training accuracy : 74.368
.


.. Epoch: 10 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.533018
.. Generator Sigmoid D(G(z))         : 0.373018
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 86.94
.. Training accuracy : 85.57
.


.. Epoch: 11 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.627197
.. Generator Sigmoid D(G(z))         : 0.529224
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 89.11
.. Training accuracy : 87.852
.


.. Epoch: 12 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.509487
.. Generator Sigmoid D(G(z))         : 0.388974
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 90.37
.. Training accuracy : 89.298
.


.. Epoch: 13 Era: 0
| training  100% Time: 0:00:00                                                 
/ validation   35% ETA:  0:00:00                                               
.. Discriminator Sigmoid D(x)   : 0.546393
.. Generator Sigmoid D(G(z))         : 0.453736
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 87.47
.. Training accuracy : 86.172
.


.. Epoch: 14 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    6% ETA:  0:00:01                                               
.. Discriminator Sigmoid D(x)   : 0.455782
.. Generator Sigmoid D(G(z))         : 0.345936
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 87.7
.. Training accuracy : 86.466
.


.. Epoch: 15 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.479013
.. Generator Sigmoid D(G(z))         : 0.40071
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 87.43
.. Training accuracy : 86.022
.


.. Epoch: 16 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.529526
.. Generator Sigmoid D(G(z))         : 0.427006
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 87.65
.. Training accuracy : 86.724
.


.. Epoch: 17 Era: 0
| training  100% Time: 0:00:00                                                 
/ validation   35% ETA:  0:00:00                                               
.. Discriminator Sigmoid D(x)   : 0.507176
.. Generator Sigmoid D(G(z))         : 0.410834
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 88.23
.. Training accuracy : 87.546
.


.. Epoch: 18 Era: 0
| training  100% Time: 0:00:00                                                 
| validation    0% ETA:  --:--:--                                              
.. Discriminator Sigmoid D(x)   : 0.647033
.. Generator Sigmoid D(G(z))         : 0.546981
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 90.56
.. Training accuracy : 89.936
.. Best training accuracy
.. Best validation accuracy
.


.. Epoch: 19 Era: 0
| training  100% Time: 0:00:00                                                 
/ validation   35% ETA:  0:00:00                                               
.. Discriminator Sigmoid D(x)   : 0.523178
.. Generator Sigmoid D(G(z))         : 0.436908
| validation  100% Time: 0:00:00                                               
.. Validation accuracy : 89.79
.. Training accuracy : 88.632
.. Training complete.Took 1.48964608333 minutes
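
The log shows three pre-training epochs, which report only a discriminator softmax cost, followed by 20 adversarial epochs, matching pre_train_discriminator = 3 and epochs = (20) in the train call. Together with k = 2, this is how the code keeps the discriminator a step ahead. The plain-Python sketch below spells out the resulting order of updates, assuming k counts discriminator updates per generator update, as in Algorithm 1 of Goodfellow et al. (2014). update_schedule is a hypothetical helper, each entry stands for a single update rather than a whole epoch, and YANN's own training loop may schedule updates differently.

```python
def update_schedule(pretrain_steps, generator_steps, k):
    """Return the order of updates as a list of 'D' (discriminator) and
    'G' (generator). Hypothetical helper, not YANN's training loop."""
    # Warm-up: only the discriminator side is trained.
    order = ["D"] * pretrain_steps
    # Adversarial phase: k discriminator updates before each generator update.
    for _ in range(generator_steps):
        order.extend(["D"] * k)
        order.append("G")
    return order


schedule = update_schedule(pretrain_steps=3, generator_steps=3, k=2)
print("".join(schedule))  # DDDDDGDDGDDG
print(schedule.count("D"), schedule.count("G"))  # 9 3
```

With k = 2 the discriminator receives twice as many updates as the generator during the adversarial phase, on top of its head start from pre-training.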